import torch
import json
import tqdm
import os
import jsonlines

import sys
sys.path.append(".")
from rob_baseline import *

from transformers import AutoTokenizer
from utils import *

def pad_tokens(input_ids, block_size, tokenizer):
    """Truncate or right-pad a token-id list to exactly ``block_size`` entries.

    Args:
        input_ids: list of token ids.
        block_size: target length.
        tokenizer: object exposing ``pad_token_id``, used as the fill value.

    Returns:
        ``(ids, attention_mask)`` where ``attention_mask`` holds 1 for every
        real token and 0 for every padding position.
    """
    # Number of real (non-pad) tokens that survive truncation.
    n_real = min(len(input_ids), block_size)
    if len(input_ids) < block_size:
        shortfall = block_size - len(input_ids)
        ids = input_ids + [tokenizer.pad_token_id] * shortfall
    else:
        ids = input_ids[:block_size]
    attention_mask = [1] * n_real + [0] * (len(ids) - n_real)
    assert len(ids) == len(attention_mask)
    return ids, attention_mask

class KTLDataset(torch.utils.data.Dataset):
    """Training dataset read from a JSON-lines file of per-sample dicts.

    The "ctxt", "qtxt" and "atxt" entries of each sample are concatenated with
    lists of token ids, so they are presumably already token-id lists
    (TODO: confirm against the data-preparation script). The NCE objective
    additionally reads "nctxt", "nqtxt" and "natxt", each a list of such lists.

    ``hparams.ktltype`` selects the mapper used by ``__getitem__``:
    "cmlm" -> ``mlm_mapper`` and "nce" -> ``nce_mapper``.
    """

    def __init__(self, file_path, tokenizer, hparams, max_c=30, max_q=30, max_a=20):
        """Load all samples and record the padding lengths.

        Args:
            file_path: JSON-lines file, read with ``jsonlines``.
            tokenizer: tokenizer object; only ``pad_token_id`` and
                ``mask_token_id`` are used by this class.
            hparams: namespace providing ``tiny``, ``use_context``,
                ``mlm_maxlen`` and ``ktltype``.
            max_c, max_q, max_a: padded lengths of context, question and answer
                in ``nce_mapper``. With ``use_context`` the answer length grows
                by ``max_c + max_q`` and the context length by ``max_q``.
        """
        self.file_path = file_path
        self.lines = []
        with jsonlines.open(file_path) as reader:
            for ix, x in tqdm.tqdm(enumerate(reader), "Reading Source"):
                # In tiny mode keep only lines 0..100.
                if hparams.tiny and ix > 100:
                    break
                self.lines.append(x)
        self.len = len(self.lines)

        self.max_q = max_q
        if hparams.use_context:
            # max_a must be widened with the ORIGINAL max_c, so update it first.
            max_a = max_a + max_c + max_q
            max_c = max_c + max_q
        self.max_c = max_c
        self.max_a = max_a
        self.tokenizer = tokenizer
        self.hparams = hparams
        self.mlm_maxlen = self.hparams.mlm_maxlen

    def mlm_mapper(self, sample):
        """Build the conditional-MLM tensors for one sample.

        The sequence c + q + a is produced once unmasked (the target) and three
        times with the context, question or answer span replaced by mask tokens.
        Every sequence is padded/truncated to ``self.mlm_maxlen``.

        Returns:
            8 tensors ``(cmasked, catmask, qmasked, qatmask, amasked, aatmask,
            label, labmask)``; each ``*mask`` is the attention mask produced by
            ``pad_tokens``.
        """
        c = sample["ctxt"]
        q = sample["qtxt"]
        a = sample["atxt"]
        mask_id = self.tokenizer.mask_token_id
        label = c + q + a
        cmasked = [mask_id] * len(c) + q + a
        qmasked = c + [mask_id] * len(q) + a
        # The mask run must cover exactly the answer span, i.e. len(a) tokens.
        amasked = c + q + [mask_id] * len(a)
        label, labmask = pad_tokens(label, self.mlm_maxlen, self.tokenizer)
        cmasked, catmask = pad_tokens(cmasked, self.mlm_maxlen, self.tokenizer)
        qmasked, qatmask = pad_tokens(qmasked, self.mlm_maxlen, self.tokenizer)
        amasked, aatmask = pad_tokens(amasked, self.mlm_maxlen, self.tokenizer)

        return (
            torch.tensor(cmasked), torch.tensor(catmask),
            torch.tensor(qmasked), torch.tensor(qatmask),
            torch.tensor(amasked), torch.tensor(aatmask),
            torch.tensor(label), torch.tensor(labmask),
        )

    def nce_mapper(self, sample):
        """Pad the positive triple and its candidate lists for NCE training.

        The candidate lists are the sample's negatives with the (unpadded)
        positive appended as the LAST element. Context, question and answer
        are padded to ``max_c``, ``max_q`` and ``max_a`` respectively.

        Returns:
            12 tensors: ids and attention masks for c, q, a, then for the
            context, question and answer candidate lists.
        """
        c, q, a = sample["ctxt"], sample["qtxt"], sample["atxt"]
        # Build candidate lists before c/q/a are rebound to their padded forms.
        nc, nq, na = sample["nctxt"] + [c], sample["nqtxt"] + [q], sample["natxt"] + [a]
        c, catmask = pad_tokens(c, self.max_c, self.tokenizer)
        q, qatmask = pad_tokens(q, self.max_q, self.tokenizer)
        a, aatmask = pad_tokens(a, self.max_a, self.tokenizer)
        nc = [pad_tokens(x, self.max_c, self.tokenizer) for x in nc]
        nq = [pad_tokens(x, self.max_q, self.tokenizer) for x in nq]
        na = [pad_tokens(x, self.max_a, self.tokenizer) for x in na]
        ncinputids = [ids for ids, _ in nc]
        ncatmask = [att for _, att in nc]
        nqinputids = [ids for ids, _ in nq]
        nqatmask = [att for _, att in nq]
        nainputids = [ids for ids, _ in na]
        naatmask = [att for _, att in na]

        return (
            torch.tensor(c), torch.tensor(catmask),
            torch.tensor(q), torch.tensor(qatmask),
            torch.tensor(a), torch.tensor(aatmask),
            torch.tensor(ncinputids), torch.tensor(ncatmask),
            torch.tensor(nqinputids), torch.tensor(nqatmask),
            torch.tensor(nainputids), torch.tensor(naatmask),
        )

    def __getitem__(self, index):
        """Map sample ``index`` with the mapper chosen by ``hparams.ktltype``.

        Raises:
            ValueError: if ``hparams.ktltype`` is neither "cmlm" nor "nce".
        """
        sample = self.lines[index]
        if self.hparams.ktltype == "cmlm":
            return self.mlm_mapper(sample)
        if self.hparams.ktltype == "nce":
            return self.nce_mapper(sample)
        raise ValueError(f"Unknown ktltype: {self.hparams.ktltype!r}")

    def __len__(self):
        """Number of samples; capped at 100 in tiny mode."""
        if self.hparams.tiny:
            return len(self.lines[0:100])
        return len(self.lines)





class ValDataset(torch.utils.data.Dataset):
    """Evaluation dataset of multiple-choice instances read from JSON lines.

    Each line is converted by ``INSTANCE_READERS[hparams.valtype]`` into a
    context, question, answer choices and gold label. The text is tokenized
    lazily in the mappers, and ``hparams.ktltype`` selects the mapper
    ("cmlm" or "nce").
    """

    def __init__(self, file_path, tokenizer, hparams, max_c=30, max_q=30, max_a=10):
        """Read and convert every instance in ``file_path``.

        Args:
            file_path: JSON-lines file; each non-blank line is one raw instance.
            tokenizer: tokenizer used by the mappers (``encode``,
                ``mask_token_id``, ``pad_token_id``).
            hparams: namespace providing ``valtype``, ``use_context``, ``tiny``,
                ``mlm_maxlen`` and ``ktltype``.
            max_c, max_q, max_a: padded lengths used by ``nce_mapper``. With
                ``use_context`` the answer length grows by ``max_c + max_q`` and
                the context length by ``max_q``.
        """
        self.file_path = file_path
        self.inputs = []

        instance_reader = INSTANCE_READERS[hparams.valtype]()

        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(file_path, encoding="utf-8") as f_in:
            for line in tqdm.tqdm(f_in):
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. extra newlines at end of file)
                    # instead of failing inside json.loads.
                    continue
                fields = json.loads(line)
                context, question, label, choices, _context_with_choices = \
                    instance_reader.fields_to_instance(fields=fields)
                self.inputs.append(
                    {"ctxt": context, "qtxt": question, "atxts": choices, "label": label}
                )

        self.len = len(self.inputs)

        self.max_q = max_q
        if hparams.use_context:
            # max_a must be widened with the ORIGINAL max_c, so update it first.
            max_a = max_a + max_c + max_q
            max_c = max_c + max_q
        self.max_c = max_c
        self.max_a = max_a
        self.tokenizer = tokenizer
        self.hparams = hparams

        self.mlm_maxlen = self.hparams.mlm_maxlen

    def mlm_mapper(self, sample):
        """Build masked-LM scoring inputs for every answer choice of ``sample``.

        For each choice, the sequence context + question + choice is built four
        times: unmasked, and with the context, question or choice span replaced
        by mask tokens. For ANLI the order is context + choice + question.
        Every sequence is padded/truncated to ``self.mlm_maxlen``, the same
        length the training dataset uses.

        Returns:
            ``(context_masked, question_masked, answer_masked, label_tokens,
            label)``. The first four tensors have shape
            (num_choices, mlm_maxlen); ``label`` is ``sample["label"]`` as a
            tensor.
        """
        context, question, choices = sample["ctxt"], sample["qtxt"], sample["atxts"]
        is_anli = self.hparams.valtype == "anli"
        tokenizer = self.tokenizer
        mask_id = tokenizer.mask_token_id
        maxlen = self.mlm_maxlen

        def _join(ctx, qst, chc):
            # ANLI orders segments context, choice, question; others put the
            # choice last.
            return ctx + chc + qst if is_anli else ctx + qst + chc

        def _pad(ids):
            return pad_tokens(ids, maxlen, tokenizer)[0]

        context_tokens = tokenizer.encode(context) if context else []
        question_tokens = tokenizer.encode(question)
        context_mask = [mask_id] * len(context_tokens)
        question_mask = [mask_id] * len(question_tokens)

        label_tokens = []
        context_masked = []
        question_masked = []
        answer_masked = []
        for choice in choices:
            choice_tokens = tokenizer.encode(choice)
            choice_mask = [mask_id] * len(choice_tokens)
            label_tokens.append(_pad(_join(context_tokens, question_tokens, choice_tokens)))
            context_masked.append(_pad(_join(context_mask, question_tokens, choice_tokens)))
            question_masked.append(_pad(_join(context_tokens, question_mask, choice_tokens)))
            answer_masked.append(_pad(_join(context_tokens, question_tokens, choice_mask)))

        label = sample["label"]

        return (
            torch.tensor(context_masked),
            torch.tensor(question_masked),
            torch.tensor(answer_masked),
            torch.tensor(label_tokens),
            torch.tensor(label),
        )

    def nce_mapper(self, sample):
        """Tokenize and pad one instance for NCE-style scoring.

        Returns:
            ``(c, catmask, q, qatmask, choices, choices_mask, label)``. The
            context and question are padded to ``max_c`` and ``max_q``. Every
            answer choice is padded to ``max_a`` and the choices are stacked
            into shape (num_choices, max_a).
        """
        tokenizer = self.tokenizer
        context_tokens = tokenizer.encode(sample["ctxt"]) if sample["ctxt"] else []
        c, catmask = pad_tokens(context_tokens, self.max_c, tokenizer)
        q, qatmask = pad_tokens(tokenizer.encode(sample["qtxt"]), self.max_q, tokenizer)
        na = [pad_tokens(tokenizer.encode(x), self.max_a, tokenizer) for x in sample["atxts"]]
        nainputids = [ids for ids, _ in na]
        naatmask = [att for _, att in na]
        label = sample["label"]

        return (
            torch.tensor(c), torch.tensor(catmask),
            torch.tensor(q), torch.tensor(qatmask),
            torch.tensor(nainputids), torch.tensor(naatmask), torch.tensor(label),
        )

    def __getitem__(self, index):
        """Map instance ``index`` with the mapper chosen by ``hparams.ktltype``.

        Raises:
            ValueError: if ``hparams.ktltype`` is neither "cmlm" nor "nce".
        """
        sample = self.inputs[index]
        if self.hparams.ktltype == "cmlm":
            return self.mlm_mapper(sample)
        if self.hparams.ktltype == "nce":
            return self.nce_mapper(sample)
        raise ValueError(f"Unknown ktltype: {self.hparams.ktltype!r}")

    def __len__(self):
        """Number of instances; capped at 100 in tiny mode."""
        if self.hparams.tiny:
            return len(self.inputs[0:100])
        return len(self.inputs)


if __name__ == '__main__':
    # Smoke test: build both datasets from the CLI arguments and print the
    # shapes of the tensors each item yields.
    hparams = get_argument_parser().parse_args()
    # `use_fast` is the documented AutoTokenizer option; `fast=True` is not a
    # from_pretrained option and so did not select the fast tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)
    ktldataset = KTLDataset(file_path=hparams.train_file, tokenizer=tokenizer, hparams=hparams)

    for row in ktldataset:
        for x in row:
            print(x.shape)

    valdataset = ValDataset(file_path=hparams.val_file, tokenizer=tokenizer, hparams=hparams)
    for row in valdataset:
        print(row)
        for x in row:
            print(x.shape)

